
Conversation

@Oleg-Goncharov
Collaborator

@Oleg-Goncharov Oleg-Goncharov commented Oct 8, 2025

Description

Breaks up the large cast_kernels.cuh and cast_gated_kernels.cuh into smaller headers organized by scaling mode.
No functional or behavioral changes: code is moved, not modified. This improves structure, readability, and maintainability (it becomes easier to navigate and extend specific scaling paths). Build includes and exports are updated accordingly; tests are unaffected.
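As an illustration only, the new layout follows a dispatch-by-scaling-mode pattern roughly like the sketch below (all names are hypothetical stand-ins; the real dispatchers live under transformer_engine/common/cast/dispatch/):

#include <stdexcept>

// Hypothetical stand-ins for the per-scaling-mode headers.
enum class ScalingMode { kDelayedTensorScaling, kMXFP8_1D, kNVFP4_1D };
struct TensorStub { ScalingMode scaling_mode; };

inline void quantize_fp8_stub(const TensorStub &, TensorStub &) {}    // cast/fp8/quantize_fp8.cuh
inline void quantize_mxfp8_stub(const TensorStub &, TensorStub &) {}  // cast/mxfp8/quantize_mxfp8.cuh
inline void quantize_nvfp4_stub(const TensorStub &, TensorStub &) {}  // cast/nvfp4/quantize_nvfp4.cuh

// The dispatch layer only inspects the scaling mode and routes;
// the kernels themselves live in the per-mode headers.
inline void quantize_dispatch_sketch(const TensorStub &input, TensorStub &output) {
  switch (output.scaling_mode) {
    case ScalingMode::kDelayedTensorScaling: quantize_fp8_stub(input, output); break;
    case ScalingMode::kMXFP8_1D:             quantize_mxfp8_stub(input, output); break;
    case ScalingMode::kNVFP4_1D:             quantize_nvfp4_stub(input, output); break;
    default: throw std::runtime_error("Unsupported scaling mode");
  }
}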

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Broke up the large cast_kernels.cuh and cast_gated_kernels.cuh into smaller headers organized by scaling mode.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Oleg-Goncharov Oleg-Goncharov requested a review from ptrendx October 8, 2025 13:43
@Oleg-Goncharov Oleg-Goncharov changed the title [common] Refactor: split cast/gated kernels by scaling mode [common] Split cast/gated kernels by scaling mode Oct 8, 2025
@ptrendx ptrendx requested a review from Copilot October 9, 2025 16:03
Contributor

Copilot AI left a comment

Pull Request Overview

This pull request refactors the large cast_kernels.cuh and cast_gated_kernels.cuh files into smaller, more organized header files structured by scaling mode. This improves code maintainability, readability, and navigation by creating specialized headers for different quantization and scaling implementations.

  • Breaks down monolithic headers into focused, scaling-mode-specific files
  • Reorganizes code structure without modifying functionality or behavior
  • Creates dispatcher files to coordinate between different scaling implementations

Reviewed Changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 1 comment.

File Description
transformer_engine/common/util/cast_kernels.cuh Removed all content - entire file deleted as part of refactoring
transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh NVFP4 quantize with transpose functionality, updated file path and namespacing
transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh New file containing NVFP4-specific quantization kernels
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh New file containing NVFP4 dequantization functionality
transformer_engine/common/cast/nvfp4/core_nvfp4.cuh New file with core NVFP4 utility functions and device operations
transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh New file containing MXFP8 quantization kernels
transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh MXFP8 gated operations, significantly reduced from original gated kernels file
transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh New file containing MXFP8 dequantization functionality
transformer_engine/common/cast/fp8/quantize_fp8.cuh New file containing FP8 quantization kernels
transformer_engine/common/cast/fp8/gated_fp8.cuh New file containing FP8 gated operations
transformer_engine/common/cast/fp8/dequantize_fp8.cuh New file containing FP8 dequantization functionality
transformer_engine/common/cast/dispatch/quantize.cuh New dispatcher file coordinating quantization across scaling modes
transformer_engine/common/cast/dispatch/gated.cuh New dispatcher file coordinating gated operations across scaling modes
transformer_engine/common/cast/dispatch/dequantize.cuh New dispatcher file coordinating dequantization across scaling modes


@ptrendx ptrendx requested a review from timmoon10 October 9, 2025 17:20
// This kernel supports only two scaling cases:
// 1. r16c0 - Rowwise NVFP4
// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
Collaborator

Do we actually support fused activation-cast kernels for NVFP4? If not, we should remove these template arguments so that we don't compile unnecessary kernels and so we prevent users from accidentally calling them. We should also remove them from the kernel, and modify quantize_helper so it errors out if you attempt something invalid.

Suggested change
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>

Collaborator Author

I intentionally left the activation template arguments and all the activation-related logic untouched, so we can easily enable it when/if it becomes part of the FP4 recipe.
@ptrendx, should we keep it, or should I just go ahead and clean up the kernel?
I also didn't want to add any functionality-related modifications to this PR, to avoid overwhelming it, and would rather do that separately in follow-up PRs, since there are some parts of the NVFP4 code that need to be reviewed/changed anyway.

Collaborator

If we don't support them, we should at least error out if you attempt to run them. Avoiding unnecessary compilations would also be useful so we don't blow up compile time and binary size.

I'm fine deferring this if we want this PR to minimize functional changes, but we should aim to catch more of these errors.
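For reference, the guard described here could be as small as the sketch below (a hypothetical standalone helper; the real code would call TE's NVTE_ERROR from inside the dispatch/quantize path rather than throwing):

#include <stdexcept>

// Sketch: reject fused activation-cast on the NVFP4 path instead of silently accepting it.
template <bool COMPUTE_ACTIVATIONS>
void nvfp4_activation_guard_sketch() {
  if constexpr (COMPUTE_ACTIVATIONS) {
    throw std::runtime_error("Fused activation-cast is not supported for NVFP4 quantization.");
  }
}

Instantiating the guard with COMPUTE_ACTIVATIONS = true would then fail loudly at run time instead of quietly running an unsupported path; avoiding the compilation of those kernels altogether would still require dropping the template arguments.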

Member

@timmoon10 @Oleg-Goncharov Let's minimize changes in this PR and just do the code movement here. Otherwise it will be very hard to properly review if the functionality was not altered.

Collaborator

@timmoon10 timmoon10 left a comment

Overall LGTM once we iron out the test failures.

@Oleg-Goncharov Oleg-Goncharov changed the title [common] Split cast/gated kernels by scaling mode [Common] Split cast/gated kernels by scaling mode Oct 16, 2025
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_cast_kernels_cleanup branch from 9202e6d to 86dd987 on October 24, 2025 20:20
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest update applies formatting-only changes to transformer_engine/common/cast/nvfp4/core_nvfp4.cuh, aligning the file with the project's clang-format configuration (Google-based style, 100-char column limit, 2-space indentation). No functional or behavioral modifications were made—function signatures and error macros were reformatted to improve consistency and readability. This change ensures that the NVFP4 core utilities, which handle FP4 quantization and conversion operations via inline PTX assembly, adhere to the repository's established formatting standards.

Important Files Changed

Filename Score Overview
transformer_engine/common/cast/nvfp4/core_nvfp4.cuh 5/5 Formatting-only changes: function signatures and error macros reformatted to match clang-format style; no functional modifications.

Confidence score: 5/5

  • This PR update is safe to merge with minimal risk, as it contains only formatting changes with no functional modifications.
  • Score reflects that the changes are purely cosmetic (clang-format enforcement) and cannot introduce bugs, regressions, or behavioral changes; all function logic remains identical.
  • No files require special attention; this is a straightforward formatting pass to ensure style consistency across the NVFP4 core utilities.

1 file reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. This update addresses formatting inconsistencies in core_nvfp4.cuh by reformatting function signatures to comply with the project's .clang-format style guide (Google-based, 100-character column limit). The changes are purely cosmetic—multi-line function signatures like compute_decoding_scaling_factor and mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding are now consistently split across lines with proper indentation, while compute_global_encode_scaling_factor_FP4 is collapsed to a single line. Additionally, the #else branches that threw errors when FP4_TYPE_SUPPORTED is undefined have been removed, simplifying the code structure. This refactoring is part of the broader PR goal to split large cast kernel headers into smaller, more maintainable files organized by scaling mode. The reformatting improves readability and navigation within device code, aligning with the project's style enforcement strategy (cpplint, clang-format, pre-commit hooks).

Important Files Changed

Filename Score Overview
transformer_engine/common/cast/nvfp4/core_nvfp4.cuh 5/5 Formatting-only changes: function signatures reformatted for readability and #else error branches removed.

Confidence score: 5/5

  • This PR is safe to merge with minimal risk as the changes are purely cosmetic and do not modify any logic.
  • Score reflects formatting-only changes with no impact on compiled code or behavior; all modifications align with project style guidelines.
  • No files require special attention; this is a straightforward formatting cleanup.

1 file reviewed, no comments


@Oleg-Goncharov
Collaborator Author

/te-ci

}

const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && (cols % 32 == 0) &&
is_supported_by_CC_100();
Member

I also really don't like the name of this function, but let's fix that in a follow-up PR.

Comment on lines 81 to 82
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
"by 32, got input of shape ", input.data.shape);
Member

This error is potentially misleading, since use_tma_kernels can also be false because the rowwise or columnwise output is missing. It is also actually wrong, since I believe the MXFP8 kernels can support only one of those outputs, right? Looking at it a second time, the logic for setting the rowwise and columnwise output variables is convoluted and not at all understandable. I will make a comment there.

Comment on lines +458 to +461
namespace detail {
using Empty = transformer_engine::Empty;
__device__ inline float identity(float value, const Empty &) { return value; }
} // namespace detail
Member

Do we need this?

Collaborator Author

Yes, it is currently used by the CastVectorizedUnaryKernelLauncher and CastVectorizedUnaryGradKernelLauncher below

Oleg-Goncharov and others added 17 commits October 28, 2025 14:40
Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
…s from the NVFP4 transpose test suite

Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
Oleg-Goncharov and others added 4 commits October 28, 2025 14:40
Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_cast_kernels_cleanup branch from c0f4a1e to b764dea on October 28, 2025 17:05
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR.

This update adds a new return_max_logit feature to fused attention APIs across PyTorch, JAX, and C++ layers, while completing the cast kernel refactoring. The attention API now optionally returns the maximum attention logit per head (useful for Muon optimizer integration and numerical stability analysis). The change threads a boolean parameter through all attention forward paths and disables FP8/F16_max512 backends when return_max_logit=true since only F16_arbitrary_seqlen supports it. Additionally, the nvidia-mathdx dependency was removed from build configurations, sigmoidf was added to math.h, submodule pointers were updated, and several utility refactors (RNG API simplification, tensor creation helpers) were applied to align with the reorganized kernel structure. The bulk of the cast kernel split (FP8, MXFP8, NVFP4 into separate headers) was completed in earlier commits and is not repeated here.

PR Description Notes:

  • The PR description states "No functional or behavior changes: code is moved, not modified," but this review includes the return_max_logit feature addition, which is a functional change. The description should be updated to reflect this.

Important Files Changed

Filename Score Overview
transformer_engine/common/fused_attn/fused_attn.cpp 4/5 Adds return_max_logit parameter to C API; disables FP8/F16_max512 backends when enabled
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu 4/5 Implements separate Max/Sum_Exp tensor allocation when return_max_logit=true
transformer_engine/pytorch/attention/dot_product_attention/backends.py 2/5 Adds max_logit support but has inconsistent FP8 handling and unverified unpacking logic
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py 3/5 Integrates max_logit with context-parallel attention; requires careful None-checking across all code paths
transformer_engine/jax/csrc/extensions/attention.cpp 2/5 Critical bug: hardcodes false for deterministic parameter instead of propagating actual value
pyproject.toml 3/5 Removes nvidia-mathdx build dependency without documentation; may break builds if code depends on it
transformer_engine/common/util/curanddx.hpp 4.5/5 New Philox4x32 RNG implementation extracted during refactor; self-contained and correct
transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh 3/5 Potential boundary check error at line 128; macro instead of constexpr at line 49
transformer_engine/common/cast/fp8/quantize_fp8.cuh 3/5 Incorrect comments claiming elt is 0 when no activation applied (lines 168, 178)
transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh 3.5/5 Uninitialized scaling_type variable if neither rowwise nor colwise scaling enabled

Confidence score: 3/5

  • This PR introduces functional changes (max_logit feature) beyond the stated "code is moved, not modified" scope, and includes critical bugs in the JAX attention extension that will silently ignore user-specified deterministic behavior.
  • Score lowered due to: (1) JAX extension bug where deterministic parameter is hardcoded to false, (2) inconsistent FP8 handling in PyTorch backends where max_logit is initialized but never populated, (3) unverified tuple unpacking logic that assumes fused_attn_fwd return structure without clear guarantees, (4) potential uninitialized variables in MXFP8/NVFP4 kernels, and (5) missing documentation for nvidia-mathdx removal.
  • Pay close attention to transformer_engine/jax/csrc/extensions/attention.cpp (deterministic parameter bug), transformer_engine/pytorch/attention/dot_product_attention/backends.py (FP8 max_logit handling), transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (None-checking for max_logit operations), and the NVFP4/MXFP8 quantization kernels with potential uninitialized variables.

Additional Comments (17)

  1. transformer_engine/common/util/curanddx.hpp, line 19-23 (link)

    style: pointer parameter alignment doesn't match .clang-format (unsigned int* should be unsigned int *)

  2. setup.py, line 165-174 (link)

    style: assertion will fail if a submodule is at a different commit than expected (starts with '+'); consider handling this case explicitly or documenting this behavior more clearly in the error message

  3. setup.py, line 180-181 (link)

    style: silently returns on any subprocess error; consider logging the exception or at least distinguishing between expected failures (not a git repo) vs. unexpected errors

  4. transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh, line 690-697 (link)

    logic: scaling_type is used uninitialized if neither rowwise nor colwise scaling is enabled. Add an else branch with NVTE_CHECK or a default initialization (see the sketch after this list).

  5. transformer_engine/pytorch/cpp_extensions/fused_attn.py, line 331 (link)

    style: amax_dims logic assumes thd vs bshd/sbhd are the only possibilities; if new layouts are added, this could silently produce incorrect max_logit shape

  6. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 697 (link)

    logic: unpacking operator *max_logit expects zero or one element, but the return_max_logit flag determines whether fused_attn_fwd returns it—check that the unpacking consistently handles both cases. Does fused_attn_fwd always return a tuple with *max_logit in the correct position when return_max_logit=True, and does it skip that return value entirely when return_max_logit=False?

  7. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 1164-1165 (link)

    logic: initializes max_logit_per_step and max_logit to None but later indexes them—verify that indexing/assignment only happens after return_max_logit=True initialization

  8. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 1254-1257 (link)

    logic: list comprehension creates max_logit_per_step tensors only when return_max_logit=True and non-FP8; ensure all subsequent accesses to max_logit_per_step[i] guard on the same condition

  9. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 1619-1623 (link)

    logic: computes max_logit via torch.clone and torch.maximum only when return_max_logit=True; confirm that max_logit_per_step is never None at these indices when the condition is met

  10. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 1629-1632 (link)

    logic: all-reduces max_logit only when return_max_logit=True; ensure that max_logit is not None at this point

  11. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 2754 (link)

    logic: unpacking *max_logit_ from fused_attn_fwd requires consistent return tuple structure when return_max_logit=True. Does fused_attn_fwd guarantee that it returns a tuple with the max_logit element appended when return_max_logit=True, and omits it otherwise?

  12. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 2776-2777 (link)

    logic: assigns max_logit_per_step[i] from max_logit_[0] only when return_max_logit=True; ensure max_logit_ is not empty

  13. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 2812-2821 (link)

    logic: clones and computes maximum of max_logit_per_step only when return_max_logit=True; verify that all max_logit_per_step entries are initialized tensors, not None

  14. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 3261 (link)

    logic: unpacking *max_logit from fused_attn_fwd requires consistent return tuple structure when return_max_logit=True

  15. transformer_engine/pytorch/attention/dot_product_attention/backends.py, line 1146 (link)

    logic: max_logit initialized as None but only populated for non-FP8 path. FP8 path (lines 1174-1233) doesn't compute max_logit, leading to inconsistent behavior. Should FP8 path also return max_logit when requested, or is this intentionally unsupported?

  16. transformer_engine/pytorch/attention/dot_product_attention/backends.py, line 1250 (link)

    logic: Unpacking *max_logit from fused_attn_fwd assumes optional return. If fused_attn_fwd doesn't return max_logit, this unpacking will fail. Does fused_attn_fwd always return max_logit as an optional value, or only when return_max_logit=True?

  17. transformer_engine/pytorch/attention/dot_product_attention/backends.py, line 1936-1938 (link)

    style: Conditional return logic duplicated across UnfusedDotProductAttention and FusedAttention. Consider extracting to a helper function to reduce duplication.
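Regarding comment 4 above, the defensive pattern being suggested looks roughly like this sketch (hypothetical names; the real enum and flags in gated_mxfp8.cuh are not reproduced here, and the kernel would use NVTE_CHECK/NVTE_ERROR rather than an exception):

#include <stdexcept>

enum class ScalingType { ROWWISE, COLWISE, BIDIMENSIONAL };

inline ScalingType select_scaling_type_sketch(const bool rowwise, const bool colwise) {
  if (rowwise && colwise) { return ScalingType::BIDIMENSIONAL; }
  if (rowwise) { return ScalingType::ROWWISE; }
  if (colwise) { return ScalingType::COLWISE; }
  // Previously this case fell through and left scaling_type uninitialized.
  throw std::runtime_error("Either rowwise or colwise scaling must be enabled.");
}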

49 files reviewed, 37 comments


output->data.dtype, OType,

constexpr int nvec = 32 / sizeof(OType);
DequantizeParam p; p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
Contributor

style: Multiple statements on one line. Split initialization and assignment for clarity:

Suggested change
DequantizeParam p; p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);

Member

Yup.

template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
cudaStream_t stream) {
using namespace dispatch;
Contributor

style: redundant using namespace dispatch inside the dispatch namespace

Member

Agreed.

float (*DActOP)(float, const ParamOP &)>
void quantize_dgated_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
using namespace dispatch;
Contributor

style: redundant using namespace dispatch inside the dispatch namespace

Member

Agreed.

Comment on lines 34 to 35
dispatch::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
Contributor

style: indentation of the continuation line changed from 4 to 1 space, which deviates from the project's .clang-format (ContinuationIndentWidth: 4)

output_vec[my_output_index + i] = out;
}
}
#endif // CUDA_VERSION
Contributor

syntax: Comment should be FP4_TYPE_SUPPORTED to match the #if on line 32, not CUDA_VERSION

Member

Agreed.

partial_dbias += elt;
}
} else {
// If no activation, elt is 0 so we can safely do this
Contributor

logic: comment says 'elt is 0' but this is incorrect – elt is only 0 if out_of_bounds is true; for in-bounds elements elt equals the loaded value. The logic relies on conditional accumulation in the IS_DACT branch but assumes unconditional safety here. Is the assumption that out-of-bounds elements are loaded as zero guaranteed by the TMA copy operation?

amax = fmaxf(amax, fabsf(elt));
}
} else {
// If no activation, elt is 0 so we can safely do this
Contributor

logic: same issue as line 168 – comment incorrectly states 'elt is 0' but elt is only guaranteed zero for out-of-bounds accesses. For in-bounds elements without activation, elt is the loaded input value.

* \brief Common functions in quantize.
*/

#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
Contributor

syntax: header guard mismatch – file is cast/core/common.cuh but guard is QUANTIZE_CORE_COMMON_CUH_
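A guard that matches the file's new location would read, for example (sketch; the exact macro name is a project convention):

#ifndef TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_

// ... contents of cast/core/common.cuh ...

#endif  // TRANSFORMER_ENGINE_CAST_CORE_COMMON_CUH_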

************************************************************************/

/*! \file common.cuh
* \brief Common functions in quantize.
Contributor

style: file comment says "in quantize" but this lives in cast/core/ – update to "Common functions in cast." or similar

Comment on lines +77 to +78
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
Contributor

style: template parameter IType name suggests input type, but it is used as OType (output) inside reduce_dbias_kernel; consider renaming to OType for clarity

Signed-off-by: Oleg Goncharov <[email protected]>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_cast_kernels_cleanup branch from dc04670 to 703556c on October 28, 2025 20:43
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest commit addresses formatting inconsistencies introduced during the previous code reorganization. Seven files received minor formatting adjustments: math.h condensed a one-line sigmoidf function to match the project's .clang-format rule allowing short functions on a single line; dispatch/quantize.cuh split a long boolean condition across two lines; dispatch/gated.cuh, gated_mxfp8.cuh, and core_nvfp4.cuh adjusted line breaks and continuation indentation in function signatures and error messages; quantize_transpose_nvfp4.cuh reformatted two PTX function calls; and gated_fp8.cuh received both formatting changes (line-continuation indentation) and a critical functional modification (removal of null-pointer validation checks for scale_inv and scale tensors, plus a macro change that broadens permitted output types beyond FP8). These changes ensure consistency with the project's 100-character column limit and continuation-indent rules as part of the broader refactoring to split monolithic cast kernel headers by scaling mode.

Important Files Changed

Filename Score Overview
transformer_engine/common/util/math.h 5/5 Reformatted sigmoidf from multi-line to single-line (pure formatting)
transformer_engine/common/cast/dispatch/quantize.cuh 5/5 Split long boolean condition across two lines (formatting only)
transformer_engine/common/cast/dispatch/gated.cuh 5/5 Reformatted error messages and function calls for readability (no logic changes)
transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh 5/5 Reformatted function signatures to break parameters across lines (cosmetic only)
transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh 4/5 Reformatted PTX calls but introduced 1-space continuation indent (violates .clang-format)
transformer_engine/common/cast/fp8/gated_fp8.cuh 2/5 Removed null-pointer checks for scale tensors and changed type-switch macro, risking null dereferences and incorrect behavior for non-FP8 outputs
transformer_engine/common/cast/nvfp4/core_nvfp4.cuh 4/5 Reordered includes and introduced nested namespaces; removed trailing whitespace; preserved commented-out code for numerical matching

Confidence score: 2/5

  • This PR introduces critical safety regressions in gated_fp8.cuh that could cause null-pointer dereferences and incorrect scaling behavior for non-FP8 outputs, despite being labeled a "no functional changes" refactoring
  • Score reflects a dangerous removal of validation checks (lines 277–280 in gated_fp8.cuh) and a macro change that may silently break FP8-specific scaling logic for other data types; additional style violations in quantize_transpose_nvfp4.cuh (1-space vs. 4-space continuation indent) and include-order changes in core_nvfp4.cuh that could affect compilation
  • Pay extremely close attention to transformer_engine/common/cast/fp8/gated_fp8.cuh (removed null checks, macro change) and verify that all callers guarantee scale-tensor allocation before invoking cast_gated_tma; review quantize_transpose_nvfp4.cuh and core_nvfp4.cuh for style compliance

7 files reviewed, 5 comments

Comment on lines 37 to 45
if constexpr (IS_DBIAS || IS_DACT) {
// backward - input is incoming gradient
input_tensor = convertNVTETensorCheck(grad);
activation_input_tensor = convertNVTETensor(input);
} else {
// forward = input is activation input
input_tensor = convertNVTETensorCheck(input);
activation_input_tensor = nullptr;
}
Member

To be honest, I really hoped that as part of this refactor we could fix this part of the API, since it is super confusing, and actually have consistent naming of the function arguments.

Collaborator Author

I refactored dispatch/quantize.cuh in a similar way to dispatch/gated.cuh, splitting it into two helpers (FWD and BWD). There is a bunch of duplicated code, but the logic should now be more intuitive.

Comment on lines +40 to +52
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
// constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
// return S_dec_b_fp8;
// NOTE: Divide by 6.0f is not elegant and not efficient.
// However, this is part of the emulation code to ensure exact match.
using namespace detail;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
const float S_dec_b = block_amax / fp4_max * S_enc;
return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}
Member

Yup, it seems to be basically the same function written twice.

// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
Member

While the comment is wrong, it actually points to the correct fact that something is skippable here: the global_amax == 0 check. This is because if amax is 0 then the scale is infinity and would already be clamped by the fminf (and if the amax is 0 we do not really care what the value of the scale is, as long as it is finite, so that multiplying by 0 does not produce NaN). This kernel is so tiny, though, that it doesn't matter.
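Put differently, after the clamp the only case that still needs handling is a zero scale, roughly as in this sketch (a hypothetical standalone helper, not the kernel's actual code):

#include <cfloat>
#include <math.h>

// raw_scale is the unclamped global encode scale: global_amax == 0 makes it +inf,
// global_amax == inf makes it 0.
inline float finalize_encode_scale_sketch(const float raw_scale) {
  float s = fminf(raw_scale, FLT_MAX);  // clamping also covers global_amax == 0 (inf -> FLT_MAX)
  if (s == 0.0f) {                      // global_amax == inf: keep the scale finite and nonzero
    s = 1.0f;
  }
  return s;
}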

Comment on lines +38 to +40
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
Member

That may be correct, and I did intend to look into this kernel anyway since it doesn't work in some cases, but that is out of scope for this PR.
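For reference, a bounds-guarded version of this indexing pattern looks roughly like the sketch below (a hypothetical standalone kernel, not the file's code; whether the real launch configuration already guarantees thread_idx < M * N is exactly the open question):

// Minimal CUDA sketch of the guarded flat-index pattern discussed above.
__global__ void indexed_copy_sketch(const float *in, float *out, const size_t M, const size_t N) {
  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (thread_idx >= M * N) {
    return;  // guard against grids that overshoot the element count
  }
  const size_t x = thread_idx % M;
  const size_t y = thread_idx / M;
  out[y * M + x] = in[y * M + x];
}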

@Oleg-Goncharov
Collaborator Author

/te-ci

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The latest update addresses several issues raised in previous reviews: (1) dispatcher refactoring splits the monolithic quantize_helper into separate quantize_fwd_helper and quantize_bwd_helper functions, removing unused template parameters from forward paths; (2) parameter reordering in gated kernels moves gated_input/input before grad to establish a consistent convention across scaling modes (forward input, then gradient); (3) naming improvements rename helpers to quantize_gated_fwd_helper and quantize_gated_bwd_helper for clarity; (4) cosmetic fixes correct #endif comments, remove redundant using namespace dispatch declarations, and adjust formatting to match .clang-format conventions. All changes are code-movement or signature-reordering refactorings with no functional modifications—the dispatch logic, kernel launch sites, and validation checks remain identical to the previous version.

Important Files Changed

Filename Score Overview
transformer_engine/common/util/math.h 5/5 Collapsed sigmoidf function from three lines to one line for consistency with other short device functions
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh 5/5 Moved NVFP4 dequantization kernel to dedicated header; corrected #endif comment from CUDA_VERSION to FP4_TYPE_SUPPORTED
transformer_engine/common/cast/nvfp4/core_nvfp4.cuh 5/5 Reordered two #include directives and removed trailing whitespace; no logic changes
transformer_engine/common/activation/activation_template.h 5/5 Updated activation function templates to call renamed _fwd_helper and _bwd_helper dispatch functions with simplified template parameters
transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh 5/5 Swapped quantize_gated parameter order (gated_input before grad) to match other scaling modes; minor formatting adjustments
transformer_engine/common/cast/fp8/gated_fp8.cuh 5/5 Reordered function parameters in cast_gated_tma and cast_gated_bwd to place input before gradient for consistency
transformer_engine/common/cast/dispatch/gated.cuh 4/5 Renamed helpers to _fwd_helper and _bwd_helper; reordered arguments in fp8::cast_gated_tma and mxfp8::quantize_gated calls to pass input/gated_input before grad
transformer_engine/common/cast/cast.cu 3.5/5 Refactored quantization API entry points to use separate quantize_fwd_helper and quantize_bwd_helper dispatchers; reordered parameters (input, activation_input, output, ...) for backward helpers
transformer_engine/common/cast/dispatch/quantize.cuh 0/5 File listed but no summary provided; changes unknown

Confidence score: 3.5/5

  • This PR is mostly safe to merge but requires careful verification that parameter reordering matches updated kernel signatures across all call sites.
  • Score lowered because (1) the parameter reordering in cast.cu for backward helpers (input, activation_input, output) may break call sites if upstream callers were not updated, (2) the argument-order change in gated kernel calls (fp8::cast_gated_tma, mxfp8::quantize_gated) must be validated against the kernel signatures to ensure input/gated_input is now expected before grad, and (3) file transformer_engine/common/cast/dispatch/quantize.cuh has no summary provided, so its changes are unreviewed.
  • Pay close attention to transformer_engine/common/cast/cast.cu (verify all callers of nvte_quantize_dbias* pass parameters in the new order) and transformer_engine/common/cast/dispatch/gated.cuh (confirm fp8::cast_gated_tma and mxfp8::quantize_gated kernel signatures accept input/gated_input before grad).

9 files reviewed, 3 comments


Comment on lines +61 to +63
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
Contributor

style: dummy_input_tensor, dummy_dbias_tensor, and dummy_workspace_tensor are declared but never used in the FP8 forward path. Consider removing them or documenting why they're passed as null. Are these dummy pointers required by the kernel signature, or can the kernel signature be simplified for the forward case?

Comment on lines +82 to +84
const Tensor *dummy_input_tensor = nullptr;
Tensor *dummy_dbias_tensor = nullptr;
Tensor *dummy_workspace_tensor = nullptr;
Contributor

style: duplicate dummy variable declarations (same as lines 61–63). Consider factoring these into a shared helper or refactoring the kernel signature to avoid passing unused parameters

Comment on lines 156 to 157
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
inputs[i], outputs[i], nullptr, detail::get_compute_stream(i % num_streams));
Contributor

logic: In multi-tensor quantize, quant_configs parameter is declared on line 138 but never used. Should each call to quantize_fwd_helper pass quant_configs (or an indexed config)?

Member

Agreed, we should pass the quant_config (quant_configs is actually a misleading name, since it suggests there are multiple of those) to those quantize functions.

@Oleg-Goncharov
Collaborator Author

Regarding the compute_decoding_scaling_factor function used in NVFP4: I intentionally left the two previously existing versions of the scaling factor computation (one is used in the more performant kernel which quantizes rowwise only; the other is used in the quantize_transpose variant), so as not to introduce any changes that could impact numerics. But it definitely needs to be refactored in a separate PR.

input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
Member

Could we move those functions involving activations to the activation-specific files? That way we could make sure that we use fast math only for the activations (and then maybe actually turn it on by default?) and not for the entire cast.cu file.

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR restructures cast and gated kernel code by splitting the monolithic cast_kernels.cuh and cast_gated_kernels.cuh files into a cleaner directory hierarchy organized by scaling mode (FP8, MXFP8, NVFP4). The refactoring introduces a dispatch layer that routes operations based on scaling_mode, with format-specific implementations in dedicated subdirectories.

Key Changes:

  • Moved util/cast.cu to cast/cast.cu and created new directory structure: cast/{core,dispatch,fp8,mxfp8,nvfp4}/
  • Deleted 2188-line cast_kernels.cuh, replacing it with 10+ focused header files
  • Added dispatch layer (dispatch/*.cuh) to route operations by scaling mode
  • Extracted common utilities to core/common.cuh
  • Updated CMakeLists.txt and include paths accordingly

Issues Found:

  • Critical: Removed null-pointer validation for scale_inv.dptr and scale.dptr in FP8 gated kernels (previously validated in original code)
  • Critical: Changed type switch macro from FP8ONLY to generic OUTPUT in gated kernels, potentially allowing non-FP8 output types where FP8-specific scaling logic is expected
  • Missing bounds check in NVFP4 dequantization could cause out-of-bounds memory access
  • Unused quant_configs parameter in multi-tensor quantize path
  • Several misleading comments and minor style issues

Confidence Score: 3/5

  • This PR has moderate risk due to removed validation logic and type safety changes in FP8 gated kernels that could cause runtime errors
  • Score reflects that while the refactoring structure is sound and most code is cleanly extracted, there are two critical issues in the FP8 gated kernel path: (1) removed null-pointer checks for scale tensors that could cause segfaults, and (2) broadened type constraints that may allow incorrect types through. The NVFP4 bounds check issue is also concerning. These are functional changes disguised as pure refactoring.
  • transformer_engine/common/cast/fp8/gated_fp8.cuh requires immediate attention for restored validation and type safety. transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh needs bounds checking.

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/cast/cast.cu 4/5 Moved from util/cast.cu. Updated include paths to new cast/ directory structure. Entry point functions unchanged except for delegation to new dispatch layer.
transformer_engine/common/cast/dispatch/quantize.cuh 3/5 New dispatcher for quantization operations. Routes to FP8/MXFP8/NVFP4 implementations based on scaling mode. Contains unused quant_configs parameter in multi-tensor path (line 138).
transformer_engine/common/cast/fp8/quantize_fp8.cuh 3/5 Extracted FP8 quantization kernels from original cast_kernels.cuh. Contains misleading comments about out-of-bounds element handling (lines 168, 178).
transformer_engine/common/cast/fp8/gated_fp8.cuh 2/5 Extracted FP8 gated activation kernels. Removed null-pointer validation for scale_inv.dptr and scale.dptr that existed in original code. Changed macro from FP8ONLY to generic OUTPUT, potentially allowing non-FP8 types.
transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh 3/5 Extracted NVFP4 quantization kernels. Contains potential boundary check issue at line 128 comparing against wrong dimension.
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh 2/5 Extracted NVFP4 dequantization kernels. Missing bounds check on thread_idx could cause out-of-bounds memory access (line 38). Incorrect comment reference to CUDA_VERSION instead of FP4_TYPE_SUPPORTED.

Sequence Diagram

sequenceDiagram
    participant API as nvte_quantize APIs<br/>(cast.cu)
    participant Dispatch as Dispatch Layer<br/>(dispatch/*.cuh)
    participant FP8 as FP8 Kernels<br/>(fp8/*.cuh)
    participant MXFP8 as MXFP8 Kernels<br/>(mxfp8/*.cuh)
    participant NVFP4 as NVFP4 Kernels<br/>(nvfp4/*.cuh)
    participant Core as Core Utilities<br/>(core/common.cuh)

    Note over API,Core: Refactoring: Split by Scaling Mode

    API->>Dispatch: quantize_fwd_helper()
    Dispatch->>Dispatch: Check output->scaling_mode
    
    alt DELAYED_TENSOR_SCALING
        Dispatch->>FP8: quantize<FP8>()
        FP8->>Core: Use common helpers
        FP8-->>Dispatch: FP8 quantized output
    else MXFP8_1D_SCALING
        Dispatch->>MXFP8: quantize<MXFP8>()
        MXFP8->>Core: Use common helpers
        MXFP8-->>Dispatch: MXFP8 quantized output
    else NVFP4_1D_SCALING
        Dispatch->>NVFP4: quantize_transpose()
        NVFP4->>Core: Use common helpers
        NVFP4-->>Dispatch: NVFP4 quantized output
    end
    
    Dispatch-->>API: Return quantized tensor
    
    Note over API,NVFP4: Gated operations follow similar pattern
    API->>Dispatch: quantize_gated_fwd_helper()
    Dispatch->>Dispatch: Check scaling_mode
    
    alt DELAYED_TENSOR_SCALING
        Dispatch->>FP8: cast_gated_fwd()
    else MXFP8_1D_SCALING
        Dispatch->>MXFP8: quantize_gated()
    end

20 files reviewed, no comments


@Oleg-Goncharov
Collaborator Author

/te-ci

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR refactors the cast and activation kernel code for better organization and maintainability:

  • Activation Functions Extracted: Moved GELU, ReLU, and SwiGLU activation functions from cast.cu into separate dedicated files (gelu.cu, relu.cu, swiglu.cu) under common/activation/
  • Kernel Organization by Scaling Mode: Split the large monolithic cast_kernels.cuh and cast_gated_kernels.cuh into smaller, focused headers organized by scaling mode (FP8, MXFP8, NVFP4) under common/cast/ subdirectories
  • Dispatch Layer: Created new dispatch layer (dispatch/quantize.cuh, dispatch/dequantize.cuh, dispatch/gated.cuh) that routes to appropriate scaling-mode-specific implementations
  • Fast-Math Option: Added NVTE_BUILD_ACTIVATION_WITH_FAST_MATH CMake option (defaults to ON) to enable --use_fast_math compilation for activation kernels

The refactoring is purely structural - no functional changes to kernel logic or behavior. The new organization improves navigability and makes it easier to extend specific scaling paths in the future.

Confidence Score: 5/5

  • This PR is safe to merge - it is a pure code refactoring with no functional changes
  • This is a well-executed refactoring that only reorganizes existing code without modifying logic. The activation functions are simply moved to separate files with identical implementations using template dispatchers. The cast kernels are split by scaling mode into logical subdirectories. All previous comments from reviewers focus on pre-existing issues in the moved code (style, potential bugs), not issues introduced by this refactoring. The build configuration properly handles the new file structure and adds an opt-in fast-math flag for activations.
  • No files require special attention - this is a straightforward refactoring

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/CMakeLists.txt 5/5 Added activation files (gelu.cu, relu.cu, swiglu.cu) to arch-specific sources list and enabled optional fast-math compilation for activation kernels via NVTE_BUILD_ACTIVATION_WITH_FAST_MATH flag (defaults to ON)
transformer_engine/common/activation/gelu.cu 5/5 New file containing GELU-related activation functions extracted from cast.cu - no logic changes, just code reorganization using activation_template.h dispatchers
transformer_engine/common/activation/relu.cu 5/5 New file containing ReLU-related activation functions extracted from cast.cu - no logic changes, just code reorganization using activation_template.h dispatchers
transformer_engine/common/activation/swiglu.cu 5/5 New file containing SwiGLU-related activation functions extracted from cast.cu - no logic changes, just code reorganization using activation_template.h dispatchers
transformer_engine/common/cast/cast.cu 5/5 Simplified to only contain cast/quantize/dequantize API entry points - all activation functions moved to separate files, all kernel implementations moved to organized subdirectories by scaling mode

Sequence Diagram

sequenceDiagram
    participant API as Public API<br/>(cast.cu, activation/*.cu)
    participant Dispatch as Dispatch Layer<br/>(dispatch/*.cuh)
    participant FP8 as FP8 Kernels<br/>(fp8/*.cuh)
    participant MXFP8 as MXFP8 Kernels<br/>(mxfp8/*.cuh)
    participant NVFP4 as NVFP4 Kernels<br/>(nvfp4/*.cuh)
    
    Note over API,NVFP4: Before: Monolithic cast_kernels.cuh with all scaling modes mixed
    
    API->>Dispatch: nvte_quantize() / nvte_gelu() etc.
    Note over Dispatch: Determines scaling mode<br/>from tensor metadata
    
    alt FP8 / Delayed Scaling
        Dispatch->>FP8: quantize_fp8() / gated_fp8()
        FP8-->>Dispatch: Result
    else MXFP8 Scaling
        Dispatch->>MXFP8: quantize_mxfp8() / gated_mxfp8()
        MXFP8-->>Dispatch: Result
    else NVFP4 Scaling
        Dispatch->>NVFP4: quantize_nvfp4() / quantize_transpose_nvfp4()
        NVFP4-->>Dispatch: Result
    end
    
    Dispatch-->>API: Return quantized tensor
    
    Note over API,NVFP4: After: Organized by scaling mode in subdirectories

5 files reviewed, no comments


@Oleg-Goncharov
Collaborator Author

/te-ci

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR refactors large monolithic kernel files (cast_kernels.cuh - 2188 lines, cast_gated_kernels.cuh) into a well-organized directory structure split by scaling mode and operation type.

Key Changes:

  • Moved util/cast.cu to cast/cast.cu and split util/cast_kernels.cuh into modular headers
  • Created new cast/dispatch/ layer with quantize.cuh, gated.cuh, dequantize.cuh dispatchers
  • Organized kernels by precision format: fp8/, mxfp8/, nvfp4/ subdirectories
  • Extracted shared utilities to cast/core/common.cuh
  • Updated activation_template.h to use new dispatcher API (simpler function calls)
  • Removed cast.cu from fast-math compilation list in CMakeLists

Benefits:

  • Improved navigability: each scaling mode (FP8, MXFP8, NVFP4, block scaling) now has dedicated files
  • Better maintainability: changes to one format won't affect others
  • Cleaner separation of concerns: dispatch logic separated from kernel implementations
  • No functional changes: existing tests pass, behavior unchanged

Confidence Score: 5/5

  • This PR is safe to merge - it's a pure code organization refactoring with no functional changes
  • This is a well-executed refactoring that moves code without modifying behavior. The large monolithic files (2188+ lines) are cleanly split by scaling mode (FP8, MXFP8, NVFP4) into logical subdirectories. All code is moved, not modified, minimizing risk. The new dispatcher layer provides clean separation of concerns. Build configuration updated appropriately (removed cast.cu from fast-math list). Previous code review comments on style/logic issues remain valid but are pre-existing, not introduced by this PR.
  • No files require special attention - this is a straightforward refactoring

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/cast/dispatch/quantize.cuh 4/5 New dispatcher for quantize operations, splits logic by scaling mode (FP8, MXFP8, NVFP4, block scaling)
transformer_engine/common/cast/dispatch/gated.cuh 5/5 New dispatcher for gated activation operations, cleanly separates FWD/BWD helpers by scaling mode
transformer_engine/common/cast/fp8/quantize_fp8.cuh 5/5 FP8 quantization kernels extracted from cast_kernels.cuh, code moved without modification
transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh 5/5 MXFP8 quantization kernels extracted from cast_kernels.cuh, organized by scaling mode
transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh 5/5 NVFP4 quantization kernels extracted from cast_kernels.cuh, clean separation by format
transformer_engine/common/cast/core/common.cuh 5/5 Common cast utilities (reduce_dbias_kernel, helpers) extracted for shared use across formats
transformer_engine/common/CMakeLists.txt 5/5 Updated path from util/cast.cu to cast/cast.cu, removed cast.cu from fast-math compilation list
transformer_engine/common/activation/activation_template.h 5/5 Updated includes to use new dispatcher headers, simplified function calls to new helpers

Sequence Diagram

sequenceDiagram
    participant App as Application
    participant Act as activation_template.h
    participant Disp as cast/dispatch/
    participant FP8 as cast/fp8/
    participant MXFP8 as cast/mxfp8/
    participant NVFP4 as cast/nvfp4/
    participant Core as cast/core/

    Note over App,Core: New Architecture: Organized by Scaling Mode

    App->>Act: act_fn() / dact_fn()
    Act->>Disp: quantize_fwd_helper() / quantize_bwd_helper()
    
    alt NVTE_DELAYED_TENSOR_SCALING
        Disp->>FP8: quantize() / dequantize()
        FP8->>Core: reduce_dbias_kernel()
    else NVTE_MXFP8_1D_SCALING
        Disp->>MXFP8: quantize() / dequantize()
        MXFP8->>Core: reduce_dbias_kernel()
    else NVTE_NVFP4_1D_SCALING
        Disp->>NVFP4: quantize_transpose()
        NVFP4->>NVFP4: core_nvfp4.cuh helpers
    else NVTE_BLOCK_SCALING_*
        Disp->>Disp: quantize_transpose_*_blockwise()
    end

    App->>Act: gated_act_fn() / dgated_act_fn()
    Act->>Disp: quantize_gated_fwd/bwd_helper()
    
    alt NVTE_DELAYED_TENSOR_SCALING
        Disp->>FP8: cast_gated_tma() / cast_gated_fwd()
    else NVTE_MXFP8_1D_SCALING
        Disp->>MXFP8: quantize_gated()
    end

    Note over Disp,NVFP4: Old: 2188 lines in cast_kernels.cuh<br/>New: Split into fp8/, mxfp8/, nvfp4/, core/

No files reviewed, no comments


@ptrendx ptrendx merged commit 0e80c84 into NVIDIA:main Oct 30, 2025
41 checks passed
@greptile-apps greptile-apps bot mentioned this pull request Oct 30, 2025
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request Nov 4, 2025
* Separated gated and dequantize kernels

Signed-off-by: Oleg Goncharov <[email protected]>

* Separated quantize, dequantize and gated functions

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed lint issues

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed persistent lint issues

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added missing compute capability 10.0 check for Quantize FP8 TMA kernels

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed the issue which was added again by autofix

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Changed files description. Completely removed non-identity activations from the NVFP4 transpose test suite

Signed-off-by: Oleg Goncharov <[email protected]>

* Removed unsupported template arguments in NVFP4 quantize

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed undefined symbol error

Signed-off-by: Oleg Goncharov <[email protected]>

* Fixed condition

Signed-off-by: Oleg Goncharov <[email protected]>

* Fixed CUDA version check

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Changed arch conditions order

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: Oleg Goncharov <[email protected]>

* Clean up

Signed-off-by: Oleg Goncharov <[email protected]>

* Small fix

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Small fix

Signed-off-by: Oleg Goncharov <[email protected]>

* Fixes per the PR review

Signed-off-by: Oleg Goncharov <[email protected]>

* Fix

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Split quantize helper into two (FWD and BWD) functions

Signed-off-by: Oleg Goncharov <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Moved activation functions from cast.cu. Removed cast.cu from the fast-math compilation list

Signed-off-by: Oleg Goncharov <[email protected]>

* Enabled fast math for activations by default

Signed-off-by: Oleg Goncharov <[email protected]>

* Disabled fast math for activations by default

Signed-off-by: Oleg Goncharov <[email protected]>

---------

Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>